{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example walks you through using [(ICML 2024) Meta Probing Agents](https://openreview.net/pdf?id=DwTgy1hXXo) (MPA)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import json\n",
    "import os\n",
    "\n",
    "import promptbench as pb\n",
    "from promptbench.models import LLMModel\n",
    "from promptbench.dataload import DatasetLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MPA contains two types of agents: the `ParaphraserAgent` and the `EvaluatorAgent`. The `ParaphraserAgent` generates paraphrases of a given evaluation sample, while the `EvaluatorAgent` checks whether the generated paraphrases satisfy certain constraints."
   ]
  },
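  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interplay between the two agents can be pictured with the toy sketch below. It is an assumption about the general paraphrase-then-check pattern, not a description of how promptbench implements it: `paraphrase_with_check`, `toy_paraphrase`, and `toy_accept` are hypothetical names, and the two toy functions stand in for the LLM-backed agents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def paraphrase_with_check(sample, paraphrase_fn, accept_fn, iters=1):\n",
    "    \"\"\"Hypothetical helper: return [sample, accepted_candidate_1, ...].\"\"\"\n",
    "    history = [sample]\n",
    "    for _ in range(iters):\n",
    "        candidate = paraphrase_fn(history[-1])\n",
    "        # Keep a candidate only if the checker accepts it.\n",
    "        if accept_fn(history[-1], candidate):\n",
    "            history.append(candidate)\n",
    "    return history\n",
    "\n",
    "\n",
    "def toy_paraphrase(text):\n",
    "    # Stand-in for an LLM paraphraser.\n",
    "    return text.replace(\"What is\", \"Which value equals\")\n",
    "\n",
    "\n",
    "def toy_accept(original, candidate):\n",
    "    # Stand-in for an LLM evaluator: the wording must change and the operands must be kept.\n",
    "    return candidate != original and \"2\" in candidate and \"3\" in candidate\n",
    "\n",
    "\n",
    "toy_history = paraphrase_with_check(\"What is 2 plus 3?\", toy_paraphrase, toy_accept)\n",
    "assert toy_history[-1] == \"Which value equals 2 plus 3?\"\n",
    "print(toy_history[-1])"
   ]
  },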
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The prompts used for paraphrasing and evaluation agents.\n",
    "from promptbench.mpa import MPA_DEFAULT_PRMOPTS\n",
    "\n",
    "# The agents for paraphrasing and evaluation. One paraphraser and one evaluator together form a pipeline.\n",
    "from promptbench.mpa import ParaphraserAgent, EvaluatorAgent, Pipeline\n",
    "\n",
    "\n",
    "# The input and output processes for the different paraphrase rules\n",
    "from promptbench.mpa import ParaphraserBasicInputProcess, ParaphraserQuestionOutputProcess, ParaphraserChoicesOutputProcess\n",
    "\n",
    "# The choice permutation can be implemented without LLMs, so it does not need any agents\n",
    "from promptbench.mpa import ChoicePermuter\n",
    "\n",
    "# The input and output processes for the different evaluation rules\n",
    "from promptbench.mpa import EvaluatorMMLUQuestionInputProcess, EvaluatorBasicOutputProcess, EvaluatorMMLUParaphrasedChoicesInputProcess, EvaluatorMMLUNewChoiceInputProcess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. creating a paraphraser and an evaluator model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "api_key = \"\"\n",
    "assert api_key != \"\", \"Please provide an API key\"\n",
    "\n",
    "# The paraphraser and evaluator models.\n",
    "paraphraser = LLMModel(\"gpt-4-turbo\", max_new_tokens=1000, temperature=0.7, api_key=api_key)\n",
    "evaluator = LLMModel(\"gpt-4-turbo\", max_new_tokens=1000, temperature=0, api_key=api_key)\n",
    "\n",
    "# The storage path for the resulting paraphrased evaluation data.\n",
    "results_dir_name = \"./mpa_results\"\n",
    "\n",
    "# Create the directory if it does not exist yet.\n",
    "os.makedirs(results_dir_name, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. preparing the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The data format depends on the prompts and on the preprocessing functions of the paraphrase rules\n",
    "# (e.g., ParaphraserBasicInputProcess imported above).\n",
    "# You can use your own prompts and preprocessing functions to prepare data for the paraphraser and evaluator.\n",
    "\n",
    "# We provide the datasets used in our paper. Download them here:\n",
    "# https://drive.google.com/drive/folders/14wRz4WyTM5pmmT55QQtEpHE4vGAgXAZ6?usp=sharing\n",
    "with open(\"data/mmlu.json\", \"r\") as file:\n",
    "    data = json.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. constructing paraphrase and evaluation pipelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Rule 0: Paraphrase the question\n",
    "\"\"\"\n",
    "paraphrase_question_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"paraphraser_paraphrase_question\"]\n",
    "evaluate_paraphrase_question_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"evaluator_paraphrase_question\"]\n",
    "\n",
    "paraphrase_question_agent = ParaphraserAgent(paraphraser, paraphrase_question_prompt, ParaphraserBasicInputProcess(), ParaphraserQuestionOutputProcess())\n",
    "evaluate_question_agent = EvaluatorAgent(evaluator, evaluate_paraphrase_question_prompt, EvaluatorMMLUQuestionInputProcess(), EvaluatorBasicOutputProcess())\n",
    "\n",
    "# Combine the paraphrase agent and the evaluation agent into a pipeline.\n",
    "# iters=1 means the pipeline runs once; see the Pipeline implementation for details about iterations.\n",
    "paraphrase_question_pipeline = Pipeline(paraphrase_question_agent, evaluate_question_agent, iters=1)\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "Rule 1: Add context to the question\n",
    "\"\"\"\n",
    "add_question_context_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"paraphraser_add_question_context\"]\n",
    "evaluate_add_question_context_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"evaluator_add_question_context\"]\n",
    "add_question_context_agent = ParaphraserAgent(paraphraser, add_question_context_prompt, ParaphraserBasicInputProcess(), ParaphraserQuestionOutputProcess())\n",
    "evaluate_add_question_context_agent = EvaluatorAgent(evaluator, evaluate_add_question_context_prompt, EvaluatorMMLUQuestionInputProcess(), EvaluatorBasicOutputProcess())\n",
    "add_question_context_pipeline = Pipeline(add_question_context_agent, evaluate_add_question_context_agent, iters=1)\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "Rule 2: Paraphrase the choices\n",
    "\"\"\"\n",
    "paraphrase_choices_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"paraphraser_paraphrase_choices\"]\n",
    "evaluate_paraphrase_choices_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"evaluator_paraphrase_choices\"]\n",
    "paraphrase_choices_agent = ParaphraserAgent(paraphraser, paraphrase_choices_prompt, ParaphraserBasicInputProcess(), ParaphraserChoicesOutputProcess())\n",
    "evaluate_choices_agent = EvaluatorAgent(evaluator, evaluate_paraphrase_choices_prompt, EvaluatorMMLUParaphrasedChoicesInputProcess(), EvaluatorBasicOutputProcess())\n",
    "paraphrase_choices_pipeline = Pipeline(paraphrase_choices_agent, evaluate_choices_agent, iters=1)\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "Rule 3: Add a new choice\n",
    "\"\"\"\n",
    "add_new_choice_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"paraphraser_add_new_choice\"]\n",
    "evaluate_new_choice_prompt = MPA_DEFAULT_PRMOPTS[\"mmlu\"][\"evaluator_add_new_choice\"]\n",
    "add_new_choice_agent = ParaphraserAgent(paraphraser, add_new_choice_prompt, ParaphraserBasicInputProcess(), ParaphraserChoicesOutputProcess())\n",
    "evaluate_new_choice_agent = EvaluatorAgent(evaluator, evaluate_new_choice_prompt, EvaluatorMMLUNewChoiceInputProcess(), EvaluatorBasicOutputProcess())\n",
    "add_new_choice_pipeline = Pipeline(add_new_choice_agent, evaluate_new_choice_agent, iters=1)\n",
    "\n",
    "\"\"\"\n",
    "Rule 4: Permute the choices\n",
    "\"\"\"\n",
    "# This rule does not need any agents, so we can directly use the ChoicePermuter"
   ]
  },
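  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about Rule 4, the cell below is a minimal, self-contained sketch of a choice permutation that keeps the answer consistent with the shuffled choices. It is not the `ChoicePermuter` implementation: the helper name `permute_choices_sketch`, the `seed` parameter, and the assumption that the answer is stored as an integer index into `choices` are all illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def permute_choices_sketch(choices, answer_index, seed=None):\n",
    "    \"\"\"Hypothetical helper: shuffle the choices and return (new_choices, new_answer_index).\"\"\"\n",
    "    rng = random.Random(seed)\n",
    "    order = list(range(len(choices)))  # order[new_position] = old_position\n",
    "    rng.shuffle(order)\n",
    "    new_choices = [choices[old_position] for old_position in order]\n",
    "    # The correct choice now sits wherever its old index ended up.\n",
    "    new_answer_index = order.index(answer_index)\n",
    "    return new_choices, new_answer_index\n",
    "\n",
    "\n",
    "demo_choices = [\"Paris\", \"Rome\", \"Berlin\", \"Madrid\"]\n",
    "shuffled_choices, shuffled_answer = permute_choices_sketch(demo_choices, answer_index=0, seed=0)\n",
    "\n",
    "# The answer text is unchanged even though its position may differ.\n",
    "assert shuffled_choices[shuffled_answer] == demo_choices[0]\n",
    "assert sorted(shuffled_choices) == sorted(demo_choices)"
   ]
  },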
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. start paraphrasing and evaluating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This list will store the paraphrased data with all the rules (0, 1, 2, 3, 4) applied\n",
    "paraphrased_data_0_1_2_3_4 = []\n",
    "# for simplicity, we only paraphrase the first 10 questions\n",
    "data = data[:10]\n",
    "\n",
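    "for idx, d in enumerate(data):\n",
    "    # The last element of each pipeline's output is used as the final paraphrase.\n",
    "    # Rule 0: paraphrase the question.\n",
    "    d_0_list = paraphrase_question_pipeline(d)\n",
    "    paraphrase_0 = d_0_list[-1]\n",
    "\n",
    "    # Rule 2: paraphrase the choices of the original sample.\n",
    "    d_2_list = paraphrase_choices_pipeline(d)\n",
    "    paraphrase_2 = d_2_list[-1]\n",
    "\n",
    "    # Rule 1 on top of rule 0: add context to the paraphrased question.\n",
    "    d_0_1_list = add_question_context_pipeline(paraphrase_0)\n",
    "    paraphrase_0_1 = d_0_1_list[-1]\n",
    "\n",
    "    # Rule 3 on top of rule 2: add a new choice to the paraphrased choices.\n",
    "    d_2_3_list = add_new_choice_pipeline(paraphrase_2)\n",
    "    paraphrase_2_3 = d_2_3_list[-1]\n",
    "\n",
    "    # Rule 4: permute the choices and update the answer accordingly.\n",
    "    new_choices, new_answer = ChoicePermuter.permute(paraphrase_2_3[\"choices\"], paraphrase_2_3[\"answer\"])\n",
    "    paraphrase_2_3_4 = copy.deepcopy(paraphrase_2_3)\n",
    "    paraphrase_2_3_4[\"choices\"] = new_choices\n",
    "    paraphrase_2_3_4[\"answer\"] = new_answer\n",
    "\n",
    "    # Merge both branches: question from rules 0+1, choices and answer from rules 2+3+4.\n",
    "    paraphrase_0_1_2_3_4 = copy.deepcopy(paraphrase_0_1)\n",
    "    paraphrase_0_1_2_3_4[\"choices\"] = new_choices\n",
    "    paraphrase_0_1_2_3_4[\"answer\"] = new_answer\n",
    "    paraphrased_data_0_1_2_3_4.append(paraphrase_0_1_2_3_4)\n",
    "\n",
    "    # Write the results after every sample so that partial progress is kept.\n",
    "    with open(os.path.join(results_dir_name, \"paraphrased_data_0+1+2+3+4.json\"), \"w\") as file:\n",
    "        json.dump(paraphrased_data_0_1_2_3_4, file, indent=4)"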
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## paraphrasing and evaluating your own data using your own rules\n",
    "\n",
    "1. Construct your own prompts for the `paraphraser` and the `evaluator`. The default prompts in `MPA_DEFAULT_PRMOPTS` (`from promptbench.mpa import MPA_DEFAULT_PRMOPTS`) can serve as a reference.\n",
    "\n",
    "2. Construct the input process and output process classes for the `paraphraser` and the `evaluator`. The default classes can serve as a reference: `from promptbench.mpa import ParaphraserBasicInputProcess, ParaphraserQuestionOutputProcess, EvaluatorMMLUQuestionInputProcess, EvaluatorBasicOutputProcess`.\n",
    "\n",
    "3. Construct the `ParaphraserAgent` and the `EvaluatorAgent` with your own prompts, input process classes, and output process classes.\n",
    "\n",
    "4. Combine the agents into a `Pipeline` and run it on your own data, as in the steps above."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "promptbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
